
[Feature] Hierarchical reduction and warp reduction intrinsics support#1762

Merged
LeiWang1999 merged 13 commits into tile-ai:main from tzj-fxz:redux on Feb 13, 2026.

Conversation

@tzj-fxz
Contributor

@tzj-fxz tzj-fxz commented Jan 31, 2026

For #1761

  • Add hierarchical reduction from warp to block to reduce workspace size
  • Add redux.sync PTX templates to support faster reduction on (u)int32 with __CUDA_ARCH__>=800
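For intuition, the two-level idea can be sketched in plain Python (a host-side model, not the PR's CUDA templates; the name `hierarchical_block_reduce` is illustrative). Reducing within each 32-lane warp first means the shared workspace only needs one slot per warp instead of one per thread:

```python
from functools import reduce

WARP_SIZE = 32

def hierarchical_block_reduce(values, op):
    """Reduce `values` (one per thread) in two levels: intra-warp, then inter-warp."""
    # Level 1: each warp reduces its 32 lanes (on GPU: shuffles / redux.sync).
    warp_results = [
        reduce(op, values[i:i + WARP_SIZE])
        for i in range(0, len(values), WARP_SIZE)
    ]
    # Level 2: only the per-warp partials touch shared memory, so the
    # workspace shrinks from len(values) slots to len(values) / 32 slots.
    return reduce(op, warp_results), len(warp_results)

# 128 threads -> a 4-slot workspace instead of a 128-slot one.
total, workspace_slots = hierarchical_block_reduce(list(range(128)), lambda a, b: a + b)
```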

Summary by CodeRabbit

  • Performance

    • Faster and more memory-efficient reductions on recent GPUs via hierarchical and architecture-aware reduction paths.
  • Bug Fixes

    • Improved handling and validation of integer reduction results to ensure correct casting and accurate outcomes.
  • Tests

    • Expanded test coverage exercising many reduction ops, dtypes and thread configurations with stronger reference comparisons.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 31, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Replaces Hopper-specific AllReduce paths with NamedBarrier for SM>=90, adds hierarchical per-warp AllReduce logic and dynamic workspace sizing when reducing_threads > 32, extends CUDA warp_reduce and barrier primitives, and generalizes reduce tests to exercise multiple ops, dtypes, and thread counts.

Changes

Cohort / File(s) Summary
AllReduce lowering & workspace
src/op/reduce.cc, src/op/finalize_reducer.cc
Remove Hopper-only run_hopper usage; invoke AllReduce with tl::NamedBarrier for targets with SM>=90; add logic to compute and pass a hierarchical workspace when reducing_threads > 32 (workspace sized as reducing_threads/32 when hierarchical).
CUDA reduction templates
src/tl_templates/cuda/reduce.h
Add warp_reduce declaration and implement barrier policies (SyncThreadsBarrier, NamedBarrier), hierarchical and butterfly AllReduce paths, warp_inter_reduce, arch-aware FP16/BF16 handling, shuffle/intrinsic optimizations, and remove the prior thread-count static_assert.
Tests — reduce coverage
testing/python/language/test_tilelang_language_reduce.py
Generalize reduce tests: add op and threads params, dispatch multiple ops (sum,max,min,abssum,absmax,bitand,bitor,bitxor), run across dtypes/shapes, and validate against Torch reference implementations.

Sequence Diagram(s)

sequenceDiagram
    participant Lane as Lane (thread lane)
    participant Warp as Warp (warp-level)
    participant Workspace as Workspace (per-warp buffer)
    participant AllReduce as AllReduce (reducer + barrier)

    Lane->>Warp: compute partial per-lane value
    Warp->>Warp: warp_reduce(value) (shuffle/intrinsics)
    Warp->>Workspace: write per-warp result (hierarchical path)
    AllReduce->>Workspace: read per-warp results
    AllReduce->>AllReduce: inter-warp reduction (with NamedBarrier sync)
    AllReduce-->>Lane: broadcast final reduced result
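The inter-warp step in the diagram can also proceed butterfly-style, as the walkthrough's "butterfly AllReduce path" suggests. A hedged host-side Python model of the `shfl_xor`-style exchange (illustrative, not the template's actual API):

```python
def butterfly_reduce(vals, op):
    # Mimics repeated __shfl_xor_sync exchanges: in each round, every lane
    # combines its value with the lane whose index differs in one bit.
    n = len(vals)
    assert n > 0 and n & (n - 1) == 0, "lane count must be a power of two"
    offset = 1
    while offset < n:
        # Build the next round from a snapshot of the current values,
        # exactly as all lanes exchange simultaneously on hardware.
        vals = [op(vals[i], vals[i ^ offset]) for i in range(n)]
        offset <<= 1
    return vals  # every lane now holds the full reduction (broadcast for free)
```

After log2(n) rounds, every lane holds the result, so no separate broadcast is needed.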

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Poem

🐰 I hopped through lanes with eager feet,

I shuffled bits till outputs meet,
Per-warp nests arranged just right,
Named barriers kept sync tight,
Now reductions hum throughout the night.

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 9.09%, which is below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly reflects the main changes: introducing hierarchical reduction and warp reduction intrinsics support, which are the primary features across all modified files.






Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 99-105: The ref_fn uses torch.uint32 in a static dtype list which
raises AttributeError on PyTorch <2.3.0; update ref_fn to conditionally include
torch.uint32 only when hasattr(torch, "uint32") (or use getattr with a fallback)
so the dtype check is built at runtime, mirroring the existing pattern used for
version-dependent dtypes like float8_e4m3fn; locate and modify the dtype
membership test in ref_fn to construct the list/set conditionally and then
perform the same res.to(A.dtype) conversion for the supported integer dtypes
(torch.uint32, torch.int32, torch.int64).
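The suggested fix can be sketched as follows (the helper name `int_reduce_dtypes` is hypothetical; the real test would build this list inline, and a fake module stands in for `torch` here so the pattern is checkable without PyTorch):

```python
import types

def int_reduce_dtypes(torch_mod):
    # Build the integer-dtype list at runtime rather than statically:
    # torch.uint32 only exists on PyTorch >= 2.3.0, so a static reference
    # raises AttributeError on older versions.
    dtypes = [torch_mod.int32, torch_mod.int64]
    if hasattr(torch_mod, "uint32"):
        dtypes.append(torch_mod.uint32)
    return dtypes

# Stand-ins for an old (< 2.3.0) and a new torch module namespace.
old_torch = types.SimpleNamespace(int32="int32", int64="int64")
new_torch = types.SimpleNamespace(int32="int32", int64="int64", uint32="uint32")
```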
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_reduce.py (1)

7-7: Consider scoping disable_cache() to avoid global test side effects.
This flips a process-wide cache flag; if other tests run in the same session, they inherit the disabled cache. If that’s not intended, wrap it in a fixture/context that re-enables after this module.

@tzj-fxz
Contributor Author

tzj-fxz commented Jan 31, 2026

Maybe we also need the performance regression tests. Where can I trigger them? @LeiWang1999

@Rachmanino
Collaborator

@regression-perf

@bucket-xv
Contributor

Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:

  1. Use intrinsic functions instead of raw redux PTX for maintainability. Refer to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-reduce-functions. Note that this is also for sm80+.
  2. May leverage redux instruction for f32 types? This is supported since sm100a. Refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync.
  3. At warp level reduction, maybe any floating types can be cast to f32 and integral types cast to int32. This cast helps leverage the redux inst. This is almost always better in performance since there is also an implicit cast for shfl.sync, which requires b32 types. (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync)

@LeiWang1999
Member

@tzj-fxz Would you mind taking a look?

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 2, 2026

> Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:
>
>   1. Use intrinsic functions instead of raw redux PTX for maintainability. Refer to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-reduce-functions. Note that this is also for sm80+.
>   2. May leverage redux instruction for f32 types? This is supported since sm100a. Refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync.
>   3. At warp level reduction, maybe any floating types can be cast to f32 and integral types cast to int32. This cast helps leverage the redux inst. This is almost always better in performance since there is also an implicit cast for shfl.sync, which requires b32 types. (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync)

Thank you for the advice. I will further implement these features.

@tzj-fxz tzj-fxz changed the title [Feature] Hierarchical reduction and redux PTX support [Feature] Hierarchical reduction and warp reduction intrinsics support Feb 2, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 377-381: The fast-path incorrectly casts 64-bit integrals to
int32_t and uses run_reduce_sync, truncating results; update the compile-time
branch in the template so the run_reduce_sync path is only taken when T is
exactly int32_t or uint32_t (remove the branch that treats other
std::is_integral_v<T> types by casting), and let 64-bit types fall through to
the shfl_xor_sync-based fallback that correctly handles 64-bit operands; adjust
the condition around run_reduce_sync (the constexpr if that currently checks
std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t> and the following
std::is_integral_v<T> branch) accordingly so no cast-to-int32_t is performed for
64-bit integrals.
- Around line 344-358: The sm100 fast-path incorrectly casts floats and calls
__reduce_max_sync/__reduce_min_sync which only accept 32-bit integer operands;
modify the branch in reduce.h so the __reduce_*_sync calls are only compiled for
32-bit integer types (e.g., when std::is_same_v<T,int> or
std::is_same_v<T,unsigned int> / std::is_integral and sizeof(T)==4), and remove
the float/half/bfloat casting there; for all other types (float, half_t,
bfloat16_t, etc.) fall through to the existing non-sm100 handling. Ensure you
static_cast the value to int/unsigned before passing to
__reduce_max_sync/__reduce_min_sync and keep the ReduceOp/MaxOp/MinOp constexpr
checks to select the right intrinsic.
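The truncation hazard the comment describes is easy to demonstrate host-side (plain Python with `ctypes` standing in for the device-side cast to `int32_t`; this is an illustration, not the template's code):

```python
import ctypes

# Squeezing a 64-bit value through a 32-bit lane, as the cast-to-int32_t
# branch did, silently keeps only the low 32 bits.
big = (1 << 40) + 7
truncated = ctypes.c_int32(big).value  # the 1 << 40 part vanishes
```

This is why 64-bit integral types must fall through to the `shfl_xor_sync`-based path, which moves the full operand.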

@bucket-xv
Contributor

bucket-xv commented Feb 3, 2026

> As far as I know, __reduce_max_sync does not accept float types. May you use assembly to use redux.sync for floating types since sm_100a?

It's indeed complex. It seems that __reduce_max_sync can be used as a substitute for ASM, but no function wraps redux.sync for floating types since sm_100a.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 344-366: The SM100 asm fast-path currently casts any T to float
and uses redux.sync.*.f32, which corrupts large integer ordering; change the
compile-time guard inside the SM100 block to only take this fast-path for
floating types (float, half_t, bfloat16_t) — e.g. wrap the asm branches with an
if constexpr that checks std::is_floating_point_v<T> || std::is_same_v<T,
half_t> || std::is_same_v<T, bfloat16_t> so integral types fall through to the
SM80+ __reduce_max_sync/__reduce_min_sync implementations; keep the existing
value_cast and asm/result handling for the floating case and do not alter the
fallback path for non-floating T.
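Why the integer exclusion matters: round-tripping through IEEE-754 binary32 collapses neighboring integers above 2^24, so `max`/`min` ordering is lost. A host-side sketch using `struct` to emulate the cast (illustrative only):

```python
import struct

def to_f32(x):
    # Round-trip through IEEE-754 binary32, like casting an integer to
    # float for the redux.sync.*.f32 fast-path.
    return struct.unpack("f", struct.pack("f", float(x)))[0]

# 2**24 + 1 is not representable in float32; it rounds back to 2**24,
# so two distinct integers compare equal after the cast.
a, b = 2**24, 2**24 + 1
```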

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 4, 2026

> As far as I know, __reduce_max_sync does not accept float types. May you use assembly to use redux.sync for floating types since sm_100a?
>
> It's indeed complex. It seems that __reduce_max_sync can be used as a substitute for ASM, but no function wraps redux.sync for floating types since sm_100a.

Fixed in the latest commit. :)

Member

@LeiWang1999 LeiWang1999 left a comment


It would be better to have some benchmark results.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 5, 2026

@regression-perf

@github-actions

github-actions bot commented Feb 5, 2026

Performance Regression Test Report

Triggered by: @tzj-fxz
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21701890076

Results

File Original Latency Current Latency Speedup
example_gemv 0.281301 0.29055 0.968167
example_warp_specialize_gemm_barrierpipe_stage2 0.038721 0.039393 0.982941
example_dequant_gemm_bf16_fp4_hopper 0.617419 0.625101 0.987711
example_gqa_decode 0.047905 0.048193 0.994024
example_dequant_gemm_fp4_hopper 1.03851 1.04228 0.996386
example_mha_sink_fwd_bhsd_sliding_window 0.0155414 0.0155875 0.997045
example_tilelang_gemm_fp8 0.317253 0.318151 0.997176
example_gemm_intrinsics 0.034593 0.034656 0.998182
example_warp_specialize_gemm_softpipe_stage2 0.038049 0.038113 0.998321
example_convolution_autotune 0.991021 0.992591 0.998418
example_tilelang_gemm_fp8_intrinsic 0.91042 0.91147 0.998848
example_tilelang_sparse_gqa_decode_varlen_indice 0.0168975 0.0169135 0.999054
example_mha_sink_fwd_bhsd 0.0157149 0.0157277 0.999182
example_mha_fwd_varlen 0.0449671 0.0449979 0.999316
example_mha_bwd_bhsd 0.0399926 0.0400117 0.999523
example_gqa_bwd 0.049014 0.0490372 0.999525
example_mha_sink_bwd_bhsd_sliding_window 0.0443848 0.0444048 0.999551
example_tilelang_gemm_splitk 1.4019 1.40252 0.999554
example_gemm_schedule 0.0322591 0.0322731 0.999566
example_mha_inference 0.079969 0.08 0.999612
fp8_lighting_indexer 0.0353686 0.0353812 0.999644
sparse_mla_fwd_pipelined 0.0946215 0.0946542 0.999655
example_group_per_split_token_cast_to_fp8 0.0103231 0.0103262 0.999698
sparse_mla_bwd 0.377131 0.377235 0.999726
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144159 0.0144195 0.999749
sparse_mla_fwd 0.129611 0.129639 0.999779
example_mha_bwd_bshd_wgmma_pipelined 0.0254211 0.0254248 0.999856
example_tilelang_block_sparse_attn 0.0100668 0.0100682 0.999864
example_gqa_bwd_tma_reduce_varlen 0.0512859 0.051292 0.999882
tilelang_example_sparse_tensorcore 0.0149007 0.0149024 0.99989
example_linear_attn_bwd 0.152459 0.152465 0.999959
example_linear_attn_fwd 0.0365545 0.0365553 0.999978
example_mla_decode 0.449224 0.449226 0.999996
example_dequant_gemv_fp16xint4 0.0283622 0.0283618 1.00001
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.015343 0.0153427 1.00002
example_blocksparse_gemm 0.0224562 0.0224555 1.00003
example_elementwise_add 0.29402 0.29399 1.0001
example_tilelang_nsa_decode 0.00730636 0.00730504 1.00018
example_vertical_slash_sparse_attn 0.231701 0.231659 1.00018
example_convolution 1.30915 1.3088 1.00027
example_tilelang_sparse_gqa_decode_varlen_mask 0.023128 0.0231209 1.0003
example_gqa_sink_bwd_bhsd 0.0408243 0.0408075 1.00041
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40117 1.40047 1.0005
example_gqa_sink_bwd_bhsd_sliding_window 0.0251555 0.0251424 1.00052
example_gqa_bwd_wgmma_pipelined 0.0687511 0.0687154 1.00052
example_mha_bwd_bshd 0.0406295 0.0406056 1.00059
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0143039 0.0142944 1.00066
example_per_token_cast_to_fp8 0.00739874 0.00739373 1.00068
topk_selector 0.0531145 0.0530765 1.00072
example_dynamic 0.651212 0.650668 1.00084
example_tilelang_nsa_fwd 0.00681937 0.00681268 1.00098
block_sparse_attn_tilelang 0.0101621 0.0101517 1.00103
example_dequant_gemm_w4a8 5.30534 5.29934 1.00113
example_dequant_gemm_bf16_mxfp4_hopper 0.557416 0.556684 1.00131
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0152878 0.0152557 1.00211
example_mha_sink_bwd_bhsd 0.0615516 0.0614027 1.00242
example_tilelang_gemm_fp8_2xAcc 0.183363 0.182821 1.00297
example_gemm_autotune 0.022176 0.02208 1.00435
example_gemm 0.022752 0.022465 1.01278
example_warp_specialize_gemm_copy_0_gemm_1 0.038817 0.037985 1.0219
example_dequant_groupedgemm_bf16_mxfp4_hopper 4.10577 3.98573 1.03012
example_warp_specialize_gemm_copy_1_gemm_0 0.037473 0.036224 1.03448
example_topk 0.010752 0.009249 1.1625

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 5, 2026

Threads=128. One block.

| Op  | M   | N   | Original (ms) | Redux (ms) | Original (tpt) | Redux (tpt) | Speedup |
|-----|-----|-----|---------------|------------|----------------|-------------|---------|
| MAX | 256 | 32  | 0.0077        | 0.0076     | 4.37 GB/s      | 4.46 GB/s   | +1.32%  |
| MIN | 256 | 32  | 0.0067        | 0.0062     | 5.05 GB/s      | 5.41 GB/s   | +7.13%  |
| MAX | 256 | 64  | 0.0068        | 0.0066     | 9.83 GB/s      | 10.09 GB/s  | +3.03%  |
| MIN | 256 | 64  | 0.0067        | 0.0063     | 9.90 GB/s      | 10.49 GB/s  | +6.35%  |
| MAX | 256 | 128 | 0.0067        | 0.0062     | 19.70 GB/s     | 21.31 GB/s  | +8.06%  |
| MIN | 256 | 128 | 0.0070        | 0.0065     | 18.88 GB/s     | 20.48 GB/s  | +7.69%  |

The regression performance test also shows that example_topk.py, which heavily uses reduce_max, has a 10%+ gain.

@tzj-fxz tzj-fxz requested a review from LeiWang1999 February 5, 2026 07:43
@tzj-fxz tzj-fxz requested a review from bucket-xv February 6, 2026 06:01
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 344-365: The fast-path in the CUDA reduce code leaves value_cast
as 0.0f for T=float (and mishandles double), so the inline asm reduces zeros;
update the branches in the block inside reduce.h (the section that sets
value_cast and emits redux.sync) to handle float explicitly by assigning
value_cast = value when std::is_same_v<T, float>, keep the existing assignments
for half_t and bfloat16_t, and ensure the fast-path is enabled only for the
three safe types (float, half_t, bfloat16_t) — do NOT include double — before
emitting the MaxOp/MinOp asm so the asm uses the real value_cast.
- Around line 368-382: The template lambda run_reduce_sync uses C++20-only
syntax and must be replaced for C++17: implement a regular templated helper
function (or a small function object) named e.g. run_reduce_sync_impl templated
on typename T_cast that contains the same if constexpr branches for ReduceOp vs
SumOp/MaxOp/MinOp/BitAndOp/BitOrOp/BitXorOp and calls the corresponding
__reduce_*_sync(mask, val), then call that helper from the existing site instead
of the lambda; ensure the helper has the same capture visibility (pass mask in
as a parameter or make it accessible) and is declared in the same scope so
ReduceOp and the __reduce_*_sync calls resolve correctly.
🧹 Nitpick comments (1)
src/tl_templates/cuda/reduce.h (1)

130-189: run_hopper duplicates run except for barrier calls — consider unifying.

The hierarchical reduction logic (warp reduce → shared-memory exchange → butterfly) is copy-pasted between run and run_hopper, differing only in barrier synchronization (__syncthreads() vs bar.sync asm). A barrier-policy template parameter or a shared helper parameterized on a sync functor would eliminate this ~60-line duplication and reduce maintenance risk when the reduction logic is updated.
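The suggested unification, one reduction routine parameterized on a barrier policy, can be sketched host-side in Python (the class and function names are illustrative stand-ins for the C++ policy templates, and the host barriers just count sync calls):

```python
class SyncThreadsBarrier:
    # Stand-in for a __syncthreads()-based policy.
    def __init__(self):
        self.syncs = 0

    def sync(self, phase):
        self.syncs += 1


class NamedBarrier:
    # Stand-in for a named-barrier policy with a fixed participant count.
    def __init__(self, num_threads):
        self.num_threads = num_threads
        self.syncs = 0

    def sync(self, phase):
        self.syncs += 1


def all_reduce(values, op, barrier):
    # One routine replaces run()/run_hopper(): only the injected barrier
    # policy differs, so the reduction logic lives in exactly one place.
    barrier.sync(1)
    acc = values[0]
    for v in values[1:]:
        acc = op(acc, v)
    barrier.sync(2)
    return acc
```

The commit that later lands in this PR follows this shape: a single `run()` taking a `Barrier` template parameter.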

@bucket-xv
Contributor

Could you add some more tests? E.g.:

  1. All datatypes and all types of reduce op, from fp4, fp8 to fp32 and int8 to int64. Also max, min, add, absmax, absmin, etc.
  2. All shapes and block sizes. The current ones do not involve num threads < 32.

If that is all prepared, I have no further suggestions.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 7, 2026

  1. Add more tests for various datatypes and operations. Since we rarely do reductions on low-precision datatypes (warp shuffle instructions also do not support them), we skip the tests for them.
  2. Add tests for threads < 32.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 46-61: The test's dispatch on op (variable op) lacks a final else
branch so unsupported op values leave B_local uninitialized and produce silent
garbage; update the branch in the block that calls T.reduce_sum / T.reduce_max /
T.reduce_min / T.reduce_abssum / T.reduce_absmax / T.reduce_bitand /
T.reduce_bitor / T.reduce_bitxor to include a final else that raises an explicit
exception (e.g., raise ValueError or AssertionError) including the invalid op
string so the test fails loudly and identifies the unsupported operation.
- Around line 98-111: ref_fn currently lacks branches for the bitwise ops
(bitand, bitor, bitxor) so res can be unassigned; add branches inside ref_fn to
handle "bitand", "bitor", and "bitxor" by performing a column-wise reduction
over dim=1 (looping over columns and applying the appropriate
torch.bitwise_and/bitwise_or/bitwise_xor) and for "bitand" initialize the
accumulator with an all-ones tensor of the same dtype/shape as a row (use
~torch.zeros_like(row) or equivalent) so types like torch.uint32/int32/int64
behave correctly; keep the existing dtype-preservation logic (the check using
A.dtype in [...]) and return the reduced tensor cast back when needed.
🧹 Nitpick comments (2)
testing/python/language/test_tilelang_language_reduce.py (2)

133-138: Bitwise reduce ops (bitand, bitor, bitxor) are dispatched in the kernel but never tested.

The kernel builder supports these ops (lines 56-61), yet test_reduce_other_op only covers ["max", "min", "abssum", "absmax"]. Adding at least one test per bitwise op (on integer dtypes) would catch regressions in the new warp-reduce paths. This also aligns with the reviewer request for broader operation coverage.


88-94: mode="ss" artificially restricted to op="sum" despite existing shared builders for other ops.

Shared-reduce builders already exist for max, min, abssum, and absmax (lines 71-84). The run_reduce function could dispatch to the right one instead of hard-coding reduce_sum_ss. This would allow the same run_reduce interface to cover shared-mode tests uniformly.

Sketch
     elif mode == "ss":
-        assert op == "sum", f"shared reduce only supports sum, got {op}"
-        program = reduce_sum_ss(M, N, dtype)
+        ss_builders = {
+            "sum": reduce_sum_ss,
+            "max": reduce_max_ss,
+            "min": reduce_min_ss,
+            "abssum": reduce_abssum_ss,
+            "absmax": reduce_absmax_ss,
+        }
+        if op not in ss_builders:
+            raise NotImplementedError(f"shared reduce not implemented for op={op}")
+        program = ss_builders[op](M, N, dtype)

@bucket-xv
Contributor

I have no further concerns. Shall we test and merge it? @LeiWang1999

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 7, 2026

I don't have a ROCm machine on hand to reproduce the error. The Hopper passes this test. It seems weird because the test function test_reduce_max_clear and the source code are not modified in the latest commit. Could anyone help?

LeiWang1999 and others added 2 commits February 11, 2026 23:31
Replace the duplicated run()/run_hopper() methods in AllReduce with a
single run() that accepts a Barrier policy template parameter
(SyncThreadsBarrier or NamedBarrier<N>). Extract the shared inter-warp
reduction logic into a warp_inter_reduce helper, and split the dispatch
into private hierarchical_reduce/butterfly_reduce methods.

Update codegen in reduce.cc and finalize_reducer.cc to emit
NamedBarrier<all_threads> for SM >= 90 targets instead of the old
all_threads + run_hopper pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@src/op/reduce.cc`:
- Around line 392-408: The hierarchical-workspace sizing is wrong when
reducing_threads < all_threads: is_hierarchical currently makes workspace_size =
reducing_threads / 32, but threads use global_warp_id = (threadIdx.x -
thread_offset) / 32 which can index up to (all_threads-1)/32, causing
out-of-bounds writes in hierarchical_reduce; fix by sizing the workspace to
cover all warps in the block when is_hierarchical (use the block thread extent
from T.thread_bounds->extent, e.g. workspace_size =
static_cast<int>(*as_const_int(T.thread_bounds->extent) + 31) / 32 or equivalent
integer division rounding up), and keep using T.AddWorkspace and
thread_reduce_args as before (symbols to edit: is_hierarchical,
reducing_threads, T.thread_bounds->extent, T.AddWorkspace, hierarchical_reduce,
global_warp_id, thread_offset).
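The corrected sizing rule can be checked with a small host-side sketch (the function name is illustrative; the real fix edits the C++ in reduce.cc):

```python
def hierarchical_workspace_slots(all_threads, reducing_threads):
    # The workspace must cover every warp in the block, because
    # global_warp_id is derived from threadIdx.x over the whole block,
    # not just over the reducing subset. Round up to whole warps.
    assert reducing_threads > 32 and reducing_threads % 32 == 0
    return (all_threads + 31) // 32

# The reported failure case: all_threads=256, reducing_threads=64.
buggy = 64 // 32                              # old sizing: 2 slots
fixed = hierarchical_workspace_slots(256, 64)  # 8 slots, covers warp ids 0..7
```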

In `@src/tl_templates/cuda/reduce.h`:
- Around line 65-78: The comment notes that hierarchical_reduce now uses barrier
phases 1, 2 and 3, so barrier IDs 1–3 are reserved for internal use; update any
user-facing barrier documentation and comments that currently state "user
barriers start from 3" to instead state that user-defined barrier IDs must start
from 4, and add a short note near the barrier policy declarations
(SyncThreadsBarrier and NamedBarrier) and in the hierarchical_reduce
documentation mentioning that hierarchical_reduce occupies phases 1–3 to avoid
future sync conflicts.
- Around line 140-164: hierarchical_reduce can write past red_buf when the
caller uses fewer "reducing_threads" than the full block (global_warp_id may
exceed reducing_threads/32); fix by adding bounds checks or by having the caller
allocate red_buf sized to total global warps: ensure before writing
red_buf[global_warp_id] and before returning/reading red_buf[group_id *
num_warps_per_group] that global_warp_id and group_id*... are within the
allocated workspace (or clamp/skips writes for warps >= reducing_warps), and
verify warp_inter_reduce is only invoked for groups that have valid entries;
update the caller that allocates red_buf (and any code that computes
reducing_threads) to size it to ceil(reducing_threads/32) or to total global
warps consistently.
🧹 Nitpick comments (1)
src/op/finalize_reducer.cc (1)

113-117: Minor inconsistency: workspace allocation threshold is >= 32 here vs > 32 in reduce.cc.

In reduce.cc (line 401), workspace is only allocated when reducing_threads > 32, while here it uses >= 32. When reducing_threads == 32, AllReduce::run dispatches to warp_reduce which doesn't use red_buf, so the allocation is harmless but wastes shared memory. Consider aligning the threshold to > 32 for consistency.

♻️ Suggested fix

```diff
-  if (reducing_threads >= 32) {
+  if (reducing_threads > 32) {
```

src/op/reduce.cc Outdated
Comment on lines 392 to 408

```diff
   bool is_hierarchical = [&]() {
     if (reducing_threads <= 32)
       return false;
     if (reducing_threads % 32 != 0)
       return false;
     if (*scale != 1)
       return false;
     return true;
   }();
   if (reducing_threads > 32) {
-    PrimExpr workspace = T.AddWorkspace(
-        *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
+    int workspace_size =
+        is_hierarchical
+            ? reducing_threads / 32
+            : static_cast<int>(*as_const_int(T.thread_bounds->extent));
+    PrimExpr workspace =
+        T.AddWorkspace(workspace_size, clear_buffer->dtype);
     thread_reduce_args.push_back(workspace);
```
Contributor


⚠️ Potential issue | 🔴 Critical

Critical: workspace undersized for hierarchical reduction when reducing_threads < all_threads.

In hierarchical_reduce (reduce.h, lines 146-153), global_warp_id = (threadIdx.x - thread_offset) / 32 ranges from 0 to (all_threads - 1) / 32. When reducing_threads < all_threads (e.g., the reduce axis only covers a subset of threads), the workspace is allocated as reducing_threads / 32 but threads in higher groups write to red_buf[global_warp_id] at indices beyond that size — an out-of-bounds shared memory write.

Example: all_threads=256, reducing_threads=64, scale=1 → workspace_size=2, but global_warp_id can reach 7.

The fix is to size the hierarchical workspace to cover all warps in the block:

🐛 Proposed fix

```diff
         int workspace_size =
             is_hierarchical
-                ? reducing_threads / 32
+                ? static_cast<int>(*as_const_int(T.thread_bounds->extent)) / 32
                 : static_cast<int>(*as_const_int(T.thread_bounds->extent));
```

Comment on lines +140 to +164

```diff
 private:
   template <typename T>
-  static TL_DEVICE T run_hopper(T x, T *red_buf = nullptr) {
+  static TL_DEVICE T hierarchical_reduce(T x, T *red_buf) {
     x = warp_reduce<T>(x, Reducer());

     constexpr int num_warps_per_group = threads / 32;
     const int global_warp_id = (threadIdx.x - thread_offset) / 32;
     const int group_id = (threadIdx.x - thread_offset) / threads;
     const int warp_id_in_group = global_warp_id % num_warps_per_group;
     const int lane_id = threadIdx.x % 32;

     Barrier::template sync<1>();
     if (lane_id == 0) {
       red_buf[global_warp_id] = x;
     }
     Barrier::template sync<2>();

     if (warp_id_in_group == 0) {
       const int group_base_warp = group_id * num_warps_per_group;
       warp_inter_reduce<Reducer, num_warps_per_group>(red_buf, group_base_warp,
                                                       lane_id);
     }
     Barrier::template sync<3>();
     return red_buf[group_id * num_warps_per_group];
   }
```
Contributor


⚠️ Potential issue | 🔴 Critical

Hierarchical reduce: verify workspace sizing in the caller.

The implementation itself is correct — global_warp_id correctly indexes across all warps in the block, and the group-based structure properly partitions the reduction. However, see the comment on src/op/reduce.cc lines 392-408 regarding workspace sizing: when reducing_threads < all_threads, global_warp_id can exceed reducing_threads / 32, causing an OOB write into red_buf.


This update adds a fallback case to the warp_reduce function that casts non-float types to float via static_cast before reducing, broadening the range of input types the reduction supports.
@LeiWang1999
Member

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21947041518

Results

File Original Latency Current Latency Speedup
example_gemv 0.281827 0.290947 0.968653
example_mha_bwd_bhsd 0.039922 0.0400004 0.998039
example_gqa_fwd_bshd_wgmma_pipelined 0.0551059 0.0551925 0.998431
example_warp_specialize_gemm_copy_0_gemm_1 0.0387049 0.0387508 0.998814
example_mha_fwd_bshd_wgmma_pipelined 0.0144947 0.0145097 0.998968
example_gqa_bwd_tma_reduce_varlen 0.0515544 0.0515956 0.999202
example_mha_fwd_varlen 0.0449643 0.0449962 0.999292
example_gqa_fwd_bshd 0.070843 0.0708917 0.999313
example_tilelang_gemm_fp8_2xAcc 0.184109 0.184205 0.999478
example_warp_specialize_gemm_softpipe_stage2 0.0382708 0.0382893 0.999518
example_gqa_bwd 0.0489986 0.0490172 0.99962
example_mha_bwd_bshd 0.040605 0.040619 0.999656
example_dynamic 0.651342 0.651558 0.999669
example_gemm 0.0227478 0.0227538 0.999737
example_tilelang_gemm_fp8_intrinsic 0.822462 0.822655 0.999766
example_tilelang_gemm_fp8 0.318699 0.318762 0.999803
example_gemm_intrinsics 0.0342416 0.0342475 0.999826
example_warp_specialize_gemm_copy_1_gemm_0 0.0383183 0.0383229 0.99988
example_linear_attn_bwd 0.151377 0.151394 0.999886
example_gqa_bwd_wgmma_pipelined 0.0687111 0.0687187 0.999889
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40111 1.40113 0.999988
example_mha_bwd_bshd_wgmma_pipelined 0.0254297 0.025429 1.00002
example_gemm_schedule 0.0322515 0.0322476 1.00012
example_gqa_decode 0.0478102 0.0478043 1.00012
example_tilelang_gemm_splitk 1.40167 1.40147 1.00015
example_gemm_autotune 0.0223578 0.0223545 1.00015
example_vertical_slash_sparse_attn 0.231686 0.23165 1.00016
example_fusedmoe_tilelang 0.131461 0.131435 1.0002
example_linear_attn_fwd 0.036545 0.0365355 1.00026
example_mha_inference 0.0790618 0.0790396 1.00028
example_elementwise_add 0.294022 0.293919 1.00035
example_dequant_gemm_w4a8 5.30339 5.30139 1.00038
tilelang_example_sparse_tensorcore 0.0149054 0.0148994 1.0004
block_sparse_attn_tilelang 0.0101616 0.0101569 1.00046
example_mha_fwd_bshd 0.0258285 0.0258142 1.00055
example_dequant_gemv_fp16xint4 0.0284035 0.0283845 1.00067
example_mha_fwd_bhsd 0.0110616 0.0110518 1.00089
example_per_token_cast_to_fp8 0.00740351 0.00739264 1.00147
example_warp_specialize_gemm_barrierpipe_stage2 0.0393415 0.0392825 1.0015
example_mha_fwd_bhsd_wgmma_pipelined 0.0141652 0.0141416 1.00167
example_group_per_split_token_cast_to_fp8 0.01035 0.0103182 1.00308
example_tilelang_nsa_decode 0.00733855 0.0073079 1.00419
example_convolution_autotune 0.995238 0.990434 1.00485
example_tilelang_nsa_fwd 0.00698125 0.00693979 1.00598
example_mha_sink_fwd_bhsd 0.0158483 0.0157431 1.00668
example_mha_sink_fwd_bhsd_sliding_window 0.0156954 0.0155763 1.00765
example_mha_sink_bwd_bhsd_sliding_window 0.0447323 0.0443747 1.00806
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.43398 3.40604 1.0082
example_tilelang_block_sparse_attn 0.010154 0.0100684 1.0085
example_tilelang_sparse_gqa_decode_varlen_indice 0.0170436 0.0168999 1.00851
sparse_mla_fwd_pipelined 0.0961369 0.0953235 1.00853
topk_selector 0.0533104 0.0528411 1.00888
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0154854 0.0153467 1.00904
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.014437 0.0142965 1.00982
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0145742 0.0144293 1.01004
fp8_lighting_indexer 0.0357196 0.0353504 1.01045
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0154387 0.0152723 1.01089
example_blocksparse_gemm 0.0226608 0.0224114 1.01113
sparse_mla_fwd 0.130805 0.129208 1.01237
example_tilelang_sparse_gqa_decode_varlen_mask 0.0234181 0.0231217 1.01282
example_mha_sink_bwd_bhsd 0.0623849 0.061392 1.01617
example_dequant_gemm_bf16_mxfp4_hopper 0.51039 0.501956 1.0168
example_gqa_sink_bwd_bhsd_sliding_window 0.0255779 0.0251493 1.01704
sparse_mla_bwd 0.383248 0.376408 1.01817
example_convolution 1.33371 1.30932 1.01863
example_dequant_gemm_bf16_fp4_hopper 0.576083 0.56536 1.01897
example_gqa_sink_bwd_bhsd 0.0416693 0.0408318 1.02051
example_dequant_gemm_fp4_hopper 1.05666 1.03524 1.02069
example_mla_decode 0.461502 0.44942 1.02688
example_topk 0.0108965 0.00928903 1.17305

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 13, 2026

@regression-perf

@github-actions

Performance Regression Test Report

Triggered by: @tzj-fxz
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/21988270388

Results

File Original Latency Current Latency Speedup
example_mha_sink_bwd_bhsd_sliding_window 0.0442404 0.0443808 0.996837
example_mha_sink_fwd_bhsd 0.0156996 0.0157402 0.997422
example_mha_fwd_bshd_wgmma_pipelined 0.0144778 0.0145103 0.997765
example_warp_specialize_gemm_barrierpipe_stage2 0.0392991 0.0393866 0.997779
example_per_token_cast_to_fp8 0.00739217 0.00740277 0.998568
example_dequant_gemv_fp16xint4 0.0283699 0.0284025 0.998851
example_tilelang_sparse_gqa_decode_varlen_indice 0.016887 0.0169028 0.999061
example_gqa_decode 0.0477291 0.0477704 0.999135
example_gqa_sink_bwd_bhsd_sliding_window 0.0251327 0.0251529 0.999198
example_dequant_gemm_bf16_mxfp4_hopper 0.502695 0.503095 0.999205
example_dequant_gemm_bf16_fp4_hopper 0.565053 0.565444 0.999309
example_tilelang_nsa_decode 0.00730645 0.00731091 0.99939
example_vertical_slash_sparse_attn 0.231584 0.231698 0.99951
example_tilelang_block_sparse_attn 0.0100741 0.010079 0.999522
example_tilelang_sparse_gqa_decode_varlen_mask 0.0231301 0.0231406 0.999547
example_mha_bwd_bshd 0.0406082 0.0406257 0.999569
example_warp_specialize_gemm_copy_0_gemm_1 0.0387597 0.0387751 0.999604
example_elementwise_add 0.293965 0.29408 0.99961
example_gqa_bwd_wgmma_pipelined 0.0686063 0.0686303 0.999651
example_gqa_fwd_bshd 0.070853 0.0708703 0.999756
example_dynamic 0.651516 0.651635 0.999818
example_fusedmoe_tilelang 0.131387 0.131411 0.999819
example_linear_attn_bwd 0.151373 0.151395 0.999854
example_tilelang_nsa_fwd 0.00694343 0.00694438 0.999863
example_tilelang_gemm_fp8 0.31871 0.31875 0.999876
example_gemm_autotune 0.0223541 0.0223558 0.999926
topk_selector 0.0528543 0.0528578 0.999934
example_gemm_intrinsics 0.0342428 0.034245 0.999936
example_gemm_schedule 0.0322591 0.0322588 1.00001
example_mha_bwd_bhsd 0.0399894 0.039989 1.00001
example_dequant_gemm_w4a8 5.30347 5.3033 1.00003
example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0144198 0.014419 1.00005
tilelang_example_sparse_tensorcore 0.0148928 0.0148919 1.00006
example_mla_decode 0.449409 0.449378 1.00007
example_linear_attn_fwd 0.0365499 0.0365465 1.00009
example_gqa_fwd_bshd_wgmma_pipelined 0.0551344 0.0551291 1.0001
example_mha_fwd_varlen 0.0449725 0.0449676 1.00011
example_gemv 0.281833 0.2818 1.00012
example_mha_bwd_bshd_wgmma_pipelined 0.0254265 0.0254233 1.00013
example_convolution_autotune 0.990672 0.990543 1.00013
example_mha_sink_bwd_bhsd 0.0615336 0.0615242 1.00015
example_group_per_split_token_cast_to_fp8 0.010323 0.010321 1.0002
example_convolution 1.30966 1.3094 1.0002
example_mha_fwd_bshd 0.0258212 0.0258159 1.00021
example_tilelang_gemm_splitk 1.40172 1.40132 1.00029
example_mha_fwd_bhsd_wgmma_pipelined 0.0141813 0.0141756 1.0004
example_gqa_sink_bwd_bhsd 0.0408309 0.0408131 1.00044
example_gqa_bwd_tma_reduce_varlen 0.0515698 0.0515469 1.00044
sparse_mla_fwd 0.129173 0.129114 1.00045
example_gemm 0.022729 0.0227168 1.00053
example_tilelang_gemm_splitk_vectorize_atomicadd 1.40163 1.40081 1.00058
example_blocksparse_gemm 0.0224296 0.0224142 1.00069
example_warp_specialize_gemm_softpipe_stage2 0.0383179 0.0382905 1.00071
sparse_mla_fwd_pipelined 0.0953829 0.0953088 1.00078
block_sparse_attn_tilelang 0.010168 0.0101595 1.00084
example_mha_sink_fwd_bhsd_sliding_window 0.0155557 0.0155423 1.00086
example_gqa_bwd 0.0490328 0.0489855 1.00097
sparse_mla_bwd 0.376703 0.376322 1.00101
example_warp_specialize_gemm_copy_1_gemm_0 0.0382959 0.0382571 1.00101
example_tilelang_gemm_fp8_2xAcc 0.184324 0.184131 1.00105
example_mha_sink_fwd_bhsd_wgmma_pipelined 0.0152989 0.015281 1.00117
example_gqa_sink_fwd_bhsd_wgmma_pipelined 0.0143022 0.0142852 1.00119
fp8_lighting_indexer 0.0353941 0.0353423 1.00147
example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window 0.0153492 0.0153254 1.00155
example_mha_inference 0.0791075 0.0789733 1.0017
example_dequant_gemm_fp4_hopper 1.03679 1.03497 1.00177
example_mha_fwd_bhsd 0.011112 0.0110648 1.00426
example_tilelang_gemm_fp8_intrinsic 0.822592 0.814128 1.0104
example_dequant_groupedgemm_bf16_mxfp4_hopper 3.43956 3.37297 1.01974
example_topk 0.0108988 0.00929587 1.17244

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit 48a8f4a into tile-ai:main Feb 13, 2026
6 checks passed